import os
import pickle

from agent import SportsAgent
from density_model.local_outlier_factor import train_lof

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is presumably set so CUDA kernel
# launches run synchronously and GPU errors surface at the failing call
# (a debugging aid). It is assigned after `from agent import SportsAgent`
# above; confirm CUDA is not initialized during that import, otherwise this
# setting may not take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import datetime
import json
import sys

# Append the current working directory, with every '/interface' substring
# removed, to the module search path. Presumably this makes the project root
# importable when the script is launched from its `interface` directory.
# NOTE(review): str.replace removes *all* occurrences, so a cwd such as
# '/home/interface_user/proj/interface' becomes '/home_user/proj'; consider
# os.path.dirname(cwd) when basename(cwd) == 'interface'.
cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
# Debug output: prints the full module search path on every run.
print(sys.path)
# Kept after the sys.path tweak above; presumably `generic` is only
# resolvable via the appended project root — verify before reordering.
from generic.data_util import read_args, load_config, ICEHOCKEY_ACTIONS


def train(args):
    """Fit a Local Outlier Factor (LOF) model and pickle it to disk.

    Steps:
      1. Load the config, debug flag and optional log-file path via
         ``load_config(args)``.
      2. Apply the learning mode from ``args.LEARN_MODE``. ``'normal'`` leaves
         the config untouched. ``'location_ha'`` runs a sanity check: it
         overrides ``config['general']['model']['input_dim']`` and prefixes
         the output filename.
      3. Build a ``SportsAgent`` and fit the model with ``train_lof``.
      4. Pickle the fitted model under ``../save_model/lof/`` (relative to the
         current working directory), creating that directory if needed.

    Args:
        args: Parsed command-line arguments. ``LEARN_MODE`` and
            ``CHECK_POINT`` are read here. The whole object is also passed
            to ``load_config``.

    Raises:
        ValueError: If ``args.LEARN_MODE`` is neither ``'normal'`` nor
            ``'location_ha'``.
    """
    config, debug_mode, log_file_path = load_config(args)
    debug_msg = 'debug_' if debug_mode else ''

    # When no log path is configured, log_file stays None and print() falls
    # back to stdout.
    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        sanity_check_msg = None

        if args.LEARN_MODE == 'location_ha':
            print('-' * 100, file=log_file, flush=True)
            print("*** Warning: Launching the sanity check. ***", file=log_file, flush=True)
            # The sanity-check input is the action vocabulary plus 4 extra features.
            config['general']['model']['input_dim'] = len(ICEHOCKEY_ACTIONS) + 4
            # Other sanity-check tag seen in this project: 'sanity_check_sd_md_tr_ha_'.
            sanity_check_msg = 'sanity_check_location_ha_'
            debug_msg = sanity_check_msg + debug_msg
            print('-' * 100, file=log_file, flush=True)
        elif args.LEARN_MODE != 'normal':
            raise ValueError("Unknown learning mode {0}".format(args.LEARN_MODE))

        # An explicit checkpoint label wins; otherwise tag the output with today's date.
        if args.CHECK_POINT is not None:
            date_label = args.CHECK_POINT
        else:
            date_label = datetime.date.today().strftime("%b-%d-%Y")

        agent = SportsAgent(config=config, log_file=log_file)

        lof_model = train_lof(agent=agent,
                              debug_mode=debug_mode,
                              sanity_check_msg=sanity_check_msg)

        # This is a file path (not a directory); the filename encodes the LOF
        # hyper-parameters read from the agent.
        lof_model_save_path = (
            "../save_model/lof/{debug}saved_lof{history}"
            "_neighbors-{neighbors}_metric-{metric}_{date}"
        ).format(
            debug=debug_msg,
            history='_history' if agent.lof_apply_history else '',
            neighbors=agent.lof_neighbors,
            metric=agent.lof_metric,
            date=date_label,
        )
        # Create the target directory so the dump does not fail on a fresh checkout.
        os.makedirs(os.path.dirname(lof_model_save_path), exist_ok=True)
        with open(lof_model_save_path, 'wb') as save_file:
            pickle.dump(lof_model, save_file)
    finally:
        # Always release the log file, including on error paths, so buffered
        # output is flushed and the handle is not leaked.
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run LOF training when
    # TRAIN_FLAG is set (it is converted to int before its truthiness is used).
    # No evaluation/test path is wired up here; a non-training run is a no-op.
    cli_args = read_args()
    should_train = bool(int(cli_args.TRAIN_FLAG))
    if should_train:
        train(cli_args)
